Chapter 3: Sampling the Imaginary¶
[1]:
# Imports: JAX/NumPyro for sampling and inference, ArviZ/Plotly/pandas for plotting.
import random
from typing import Sequence
import arviz as az
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
import plotly
import plotly.graph_objects as go
import plotly.io as pio
from scipy.stats import gaussian_kde
# Use Plotly as the pandas plotting backend for interactive figures.
pd.options.plotting.backend = "plotly"
# Fixed seed so the stochastic results below are reproducible.
seed = 84735
pio.templates.default = "plotly_white"
# Single PRNG key shared by later sampling cells.
# NOTE(review): reusing the same key across cells yields identical/correlated
# draws — consider jax.random.split for independent samples.
rng = jax.random.PRNGKey(seed)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Code¶
Code 3.1¶
[2]:
# Bayes' theorem for the vampire-test example:
# P(vampire | +) = P(+ | vampire) * P(vampire) / P(+).
p_positive_vampire = 0.95  # test sensitivity
p_positive_mortal = 0.01   # false-positive rate among mortals
p_vampire = 0.001          # prevalence of vampires
# Marginal probability of a positive test (law of total probability).
p_positive = p_vampire * p_positive_vampire + (1 - p_vampire) * p_positive_mortal
p_vampire_positive = p_vampire * p_positive_vampire / p_positive
p_vampire_positive
[2]:
0.08683729433272395
Code 3.2¶
[3]:
def calculate_posterior(W: int, L: int, prior: Sequence[float], grid_size: int):
    """Grid-approximate the posterior for the globe-tossing model.

    W / L are the observed water / land counts, `prior` is an (unnormalized)
    prior evaluated on a `grid_size`-point grid over [0, 1].  Returns the
    normalized posterior over that grid.
    """
    p = jnp.linspace(0, 1, grid_size)
    # Binomial likelihood of observing W successes in W + L trials at each p.
    binom = dist.Binomial(total_count=W + L, probs=p)
    unnormalized = jnp.exp(binom.log_prob(W)) * prior
    # Normalize so the grid probabilities sum to one.
    return unnormalized / unnormalized.sum()
# Observed data: 6 water in 9 globe tosses.
W = 6
L = 3
grid_size = 1_000
# Flat (unnormalized) prior over the grid.
prior = jnp.full(grid_size, 1)
p_grid = jnp.linspace(0, 1, grid_size)
posterior = calculate_posterior(W, L, prior, grid_size)
Code 3.3¶
[4]:
# Draw 10,000 grid indices in proportion to posterior mass, then map the
# indices back to parameter values on the grid.
sample_idx = dist.Categorical(probs=posterior).sample(rng, (10_000,))
samples = p_grid[sample_idx]
Code 3.4¶
[5]:
# Trace plot of the posterior samples against draw index.
# BUG FIX: with mode="markers" the `line` style is ignored, so the translucent
# blue never applied — the color must be set on `marker` for point styling.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=jnp.arange(10_000),
        y=samples,
        mode="markers",
        marker={"color": "rgba(0, 0, 255, 0.2)"},
    )
)
Code 3.5¶
[6]:
# Kernel-density plot of the posterior samples; hdi_prob=1 shows the full range.
az.plot_density({"": samples}, hdi_prob=1)
[6]:
array([[<AxesSubplot: >]], dtype=object)
Code 3.6¶
[7]:
# Posterior probability that p < 0.5, computed exactly from the grid.
posterior[p_grid < 0.5].sum()
[7]:
DeviceArray(0.17187458, dtype=float32)
Code 3.7¶
[8]:
# Same probability estimated from the samples: fraction of draws below 0.5.
(samples < 0.5).sum() / samples.shape[0]
[8]:
DeviceArray(0.1756, dtype=float32)
Code 3.8¶
[9]:
# Fraction of posterior samples falling between 0.5 and 0.75.
((samples > 0.5) & (samples < 0.75)).sum() / samples.shape[0]
[9]:
DeviceArray(0.5978, dtype=float32)
Code 3.9¶
[10]:
# 80th percentile of the posterior samples (lower 80% boundary).
jnp.quantile(samples, 0.8)
[10]:
DeviceArray(0.7617617, dtype=float32)
Code 3.10¶
[11]:
# Middle 80% interval: the 10th and 90th percentiles.
jnp.quantile(samples, jnp.array([0.1, 0.9]))
[11]:
DeviceArray([0.45245245, 0.8128128 ], dtype=float32)
Code 3.11¶
[12]:
# New data: 3 tosses, all water — yields a posterior strongly skewed toward 1.
posterior = calculate_posterior(W=3, L=0, prior=jnp.full(1_000, 1), grid_size=1_000)
samples = p_grid[dist.Categorical(probs=posterior).sample(rng, (10_000,))]
Code 3.12¶
[13]:
def percentile_interval(samples, prob):
    """Equal-tailed interval containing the middle `prob` posterior mass.

    BUG FIX: the original folded `prob` through min(prob, 1 - prob) before
    splitting it, which made e.g. prob=0.2 return the middle 80% interval
    instead of the middle 20% one.  Here the tail mass is (1 - prob) / 2 on
    each side, so `prob` is always the interval mass.  Calls with prob=0.5
    (as below) are unaffected.
    """
    tail = (1 - prob) / 2
    return jnp.quantile(samples, jnp.array([tail, 1 - tail]))
# Middle 50% percentile interval of the skewed posterior.
percentile_interval(samples, 0.5)
[13]:
DeviceArray([0.7067067, 0.9319319], dtype=float32)
Code 3.13¶
[14]:
# 50% highest posterior density interval — narrower than the percentile
# interval because this posterior is strongly skewed.
numpyro.diagnostics.hpdi(samples, prob=0.5)
[14]:
array([0.8398398, 0.998999 ], dtype=float32)
Code 3.14¶
[15]:
# MAP estimate taken directly from the grid posterior (here the boundary p = 1).
p_grid[jnp.argmax(posterior)]
[15]:
DeviceArray(1., dtype=float32)
Code 3.15¶
[16]:
# Mode estimated from the samples via a narrow-bandwidth kernel density.
samples[jnp.argmax(gaussian_kde(samples, bw_method=0.01)(samples))]
[16]:
DeviceArray(0.985986, dtype=float32)
Code 3.16¶
[17]:
# Mean and median point estimates of the posterior samples.
display(samples.mean())
jnp.median(samples)
DeviceArray(0.8006291, dtype=float32)
[17]:
DeviceArray(0.8408408, dtype=float32)
Code 3.17¶
[18]:
# Expected absolute loss if we committed to the point estimate d = 0.5.
jnp.sum(jnp.abs(0.5 - p_grid) * posterior)
[18]:
DeviceArray(0.31287518, dtype=float32)
Code 3.18¶
[19]:
# Expected absolute loss for every candidate decision d on the grid,
# vectorized over d with jax.vmap, then plotted against d.
loss = jax.vmap(lambda d: jnp.sum(jnp.abs(d - p_grid) * posterior))(p_grid)
display(pd.DataFrame(loss, index=p_grid).plot())
Code 3.19¶
[20]:
# The absolute-loss-minimizing decision — equals the posterior median.
p_grid[jnp.argmin(loss)]
[20]:
DeviceArray(0.8408408, dtype=float32)
Code 3.20¶
[21]:
# Analytic Binomial(2, 0.7) probabilities for 0, 1, and 2 successes.
jnp.exp(dist.Binomial(total_count=2, probs=0.7).log_prob(jnp.arange(3)))
[21]:
DeviceArray([0.08999996, 0.42000008, 0.48999974], dtype=float32)
Code 3.21¶
[22]:
# One dummy observation from Binomial(2, 0.7) via numpyro's seed handler.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample("dummy_w", dist.Binomial(total_count=2, probs=0.7))
dummy_w
[22]:
DeviceArray(2, dtype=int32, weak_type=True)
Code 3.22¶
[23]:
# Ten dummy observations from the same distribution.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(10,)
    )
dummy_w
[23]:
DeviceArray([0, 1, 2, 1, 1, 2, 1, 2, 1, 2], dtype=int32, weak_type=True)
Code 3.23¶
[24]:
# 100,000 dummy observations; the empirical frequencies should approximate
# the analytic Binomial(2, 0.7) probabilities.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(100_000,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
# Count occurrences of each outcome and convert counts to relative frequencies.
dummy_w["freq"] = 1
dummy_w.groupby("dummy_w").sum() / 100_000
[24]:
| freq | |
|---|---|
| dummy_w | |
| 0 | 0.09004 |
| 1 | 0.42109 |
| 2 | 0.48887 |
Code 3.24¶
[25]:
# Sampling distribution of W for 9-toss experiments at fixed p = 0.7.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=9, probs=0.7), sample_shape=(100_000,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w.plot(kind="hist")
Code 3.25¶
[26]:
# Sampling distribution of W for 9 tosses at a single fixed p = 0.6.
w = dist.Binomial(total_count=9, probs=0.6).sample(jax.random.PRNGKey(seed), (10_000,))
pd.DataFrame(w).plot(kind="hist")
Code 3.26¶
[27]:
# Posterior predictive distribution: one Binomial(9, p) draw per posterior
# sample of p, propagating parameter uncertainty into the prediction.
w = dist.Binomial(total_count=9, probs=samples).sample(
    jax.random.PRNGKey(seed),
)
pd.DataFrame(w).plot(kind="hist")
Hard¶
3H1¶
[28]:
# fmt: off
# First (births_1) and second (births_2) births of 100 two-child families;
# 1 = boy, 0 = girl (Statistical Rethinking `homeworkch3` data).
births_1 = [
    1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
    0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0,
    0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,
    0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
]
births_2 = [
    0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,
    1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
    0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
]
# Shape (2, 100): row 0 = first births, row 1 = second births.
births = jnp.array([births_1, births_2])
[29]:
# 3H1: grid-approximate the posterior for p(boy), pooling both birth orders.
grid_size = 1_000
p_grid = jnp.linspace(0, 1, grid_size)
# Uniform prior; the constant 0.5 cancels when the posterior is normalized.
prior = jnp.full(grid_size, 0.5)
likelihood = jnp.exp(
    dist.Binomial(total_count=births.size, probs=p_grid).log_prob(births.sum())
)
raw_posterior = likelihood * prior
posterior = raw_posterior / raw_posterior.sum()
# MAP estimate: the grid point with the largest posterior mass.
map_p = p_grid[jnp.argmax(posterior)]
print(f"p={map_p:.2%} maximizes the posterior probability.")
p=55.46% maximizes the posterior probability.
3H2¶
[30]:
# 3H2: sample the posterior and report highest posterior density intervals.
# FIX: the printed labels said "HDPI"; the interval is the HPDI (highest
# posterior density interval), matching numpyro.diagnostics.hpdi.
posterior_samples = p_grid[dist.Categorical(probs=posterior).sample(rng, (10_000,))]
print(f"50% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.5)}")
print(f"89% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.89)}")
print(f"97% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.97)}")
50% HPDI: [0.5275275 0.5745746]
89% HPDI: [0.4994995 0.6096096]
97% HPDI: [0.47847846 0.6286286 ]
3H3¶
[31]:
# 3H3: posterior predictive check on the full dataset (births.size = 200).
# One Binomial(200, p) draw per posterior sample of p.
posterior_predictive_samples = dist.Binomial(
    total_count=births.size, probs=posterior_samples
).sample(rng)
print(
    f"Posterior predictive distribution of number of boys has mean {posterior_predictive_samples.mean():.0f} "
    f"vs observation of {births.sum()}: we're evaluating model against training data"
)
pd.DataFrame(posterior_predictive_samples, columns=["n_boys"]).plot(kind="hist")
Posterior predictive distribution of number of boys has mean 111 vs observation of 111: we're evaluating model against training data
3H4¶
[32]:
# 3H4: posterior predictive for first births only — simulate 100 first births
# (births.shape[1] == 100) per posterior draw and compare to the observed
# count of first-born boys.
# FIX: corrected the typo "obersvation" in the printed message.
posterior_predictive_samples = dist.Binomial(
    total_count=births.shape[1], probs=posterior_samples
).sample(rng)
print(
    f"Posterior predictive distribution of first born sons has mean {posterior_predictive_samples.mean():.0f} "
    f"vs observation of {births[0].sum()}; still reasonable but not as good as purely 'in-sample'."
)
pd.DataFrame(posterior_predictive_samples, columns=["n_first_born_boys"]).plot(
    kind="hist"
)
Posterior predictive distribution of first born sons has mean 55 vs observation of 51; still reasonable but not as good as purely 'in-sample'.
3H5¶
[33]:
# 3H5: posterior predictive for second births that FOLLOW a first-born girl.
# total_count = number of first-born girls (49) per posterior draw of p.
# BUG FIX: the observation printed for comparison was births[1].sum(), which
# counts ALL second-born boys (60), not boys whose older sibling is a girl.
# The correct observed count restricts to families where births[0] == 0,
# i.e. births[1][births[0] == 0].sum() (39 in this data).
posterior_predictive_samples = dist.Binomial(
    total_count=jnp.logical_not(births[0]).sum(), probs=posterior_samples
).sample(rng)
print(
    f"PPD of boys with big sisters of {posterior_predictive_samples.mean():.0f} "
    f"is completely out of line with observations of {births[1][births[0] == 0].sum()}: we didn't model "
    "the correlation between first and second birth that's present in our dataset."
)
pd.DataFrame(posterior_predictive_samples, columns=["n_boys_with_big_sister"]).plot(
    kind="hist"
)
PPD of boys with big sisters of 27 is completely out of line with observations of 60: we didn't model the correlation between first and second birth that's present in our dataset.
[ ]: